#!/usr/bin/env python
"""Generates wnids using class names for torchvision dataset"""

import argparse
import torchvision
from nbdt import data
from nltk.corpus import wordnet as wn
from pathlib import Path
from nbdt.utils import Colors, generate_kwargs
from nbdt.thirdparty.wn import (
    maybe_install_wordnet,
    synset_to_wnid,
    write_wnids,
    FakeSynset,
)
import os

# Called before any `wn` lookup further down in this script.
# NOTE(review): presumably fetches the WordNet corpus when it is missing --
# confirm against nbdt.thirdparty.wn.
maybe_install_wordnet()

# Every dataset name accepted by --dataset: three names listed here, followed
# by the names exported by each nbdt.data submodule.
datasets = (
    "CIFAR10",
    "CIFAR100",
    "Cityscapes",
    *data.imagenet.names,
    *data.custom.names,
    *data.pascal_context.names,
    *data.lip.names,
    *data.ade20k.names,
)


# Command-line interface. The module docstring doubles as the --help summary.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--dataset",
    choices=datasets,
    default="CIFAR10",
    help="Dataset whose class names are converted to WNIDs.",
)
parser.add_argument(
    "--root",
    default="./nbdt/wnids",
    help="Directory that receives the generated <dataset>.txt WNID file.",
)
# nargs="*" accepts any number of names, so the help text speaks of class
# names in the plural.
parser.add_argument(
    "--classes",
    type=str,
    nargs="*",
    help="INSTEAD of writing WNIDs for a dataset, convert JUST"
    " these class names to WNIDs and print them.",
)
# Let the custom-dataset module register its own options on the same parser.
data.custom.add_arguments(parser)
args = parser.parse_args()

# Decide which class names get converted: the dataset's own classes, unless
# explicit names were supplied through --classes.
if not args.classes:
    # Look up the dataset class by name on the nbdt.data package.
    dataset = getattr(data, args.dataset)
    dataset_kwargs = generate_kwargs(
        args,
        dataset,
        name=f"Dataset {args.dataset}",
        keys=data.custom.keys,
        globals=globals(),
    )
    # These four datasets are constructed without download=True.
    # NOTE(review): presumably their constructors do not accept it -- confirm.
    if args.dataset not in (
        "Cityscapes",
        "PascalContext",
        "LookIntoPerson",
        "ADE20K",
    ):
        dataset_kwargs["download"] = True

    # Replace the dataset class with an instance rooted at ./data.
    dataset = dataset(root="./data", **dataset_kwargs)

    if args.dataset == "Cityscapes":
        # Keep only the names of classes not flagged ignore_in_eval.
        classes = [
            entry.name for entry in dataset.classes if not entry.ignore_in_eval
        ]
    elif args.dataset == "PascalContext":
        # Drop the "background" entry.
        classes = [label for label in dataset.classes if label != "background"]
    else:
        classes = dataset.classes
else:
    classes = args.classes

# Destination of the generated WNID list: <root>/<dataset>.txt.
path = Path(args.root) / f"{args.dataset}.txt"
if not args.classes:
    # The file is only written when --classes is absent, so only then create
    # the output directory; with --classes the WNIDs are merely printed and no
    # stray directory should be left behind.
    path.parent.mkdir(parents=True, exist_ok=True)

# Class names for which no WordNet noun synset was found.
failures = []

# Manual overrides, keyed by class name. Each value is (lemma, sense index):
# the lemma queried in WordNet and the position of the wanted synset in the
# list of noun synsets returned for it (negative indices count from the end).
# NOTE(review): presumably these exist because the first noun sense is not the
# intended one for these classes -- confirm before changing any index.
_SYNSET_OVERRIDES = {
    "aquarium_fish": ("fingerling", 0),
    "arcade_machine": ("slot_machine", 0),
    "background": ("background", 1),
    "barrel": ("barrel", 1),
    "beaver": ("beaver", -1),
    "booth": ("booth", 1),
    "blind": ("blind", 2),
    "bulletin_board": ("bulletin_board", 1),
    "canopy": ("canopy", 2),
    "case": ("case", -1),
    "castle": ("castle", 1),
    "column": ("column", 5),
    "cushion": ("cushion", 2),
    "diningtable": ("dining_table", 0),
    "earth": ("earth", 1),
    "escalator": ("escalator", 1),
    "flatfish": ("flatfish", 1),
    "food": ("food", 1),
    "glove": ("glove", 1),
    "grandstand": ("grandstand", 1),
    "lamp": ("lamp", 1),
    "land": ("land", 1),
    "leopard": ("leopard", 1),
    "left-arm": ("arm", 0),
    "left-leg": ("leg", 0),
    "left-shoe": ("shoe", 0),
    "lobster": ("lobster", 1),
    "maple_tree": ("maple", 1),
    "microwave": ("microwave", 1),
    "monitor": ("monitor", 3),
    "otter": ("otter", 1),
    "ottoman": ("ottoman", 2),
    "path": ("path", 2),
    "plant": ("plant", 1),
    "plate": ("plate", 3),
    "pottedplant": ("plant", 1),
    "raccoon": ("raccoon", 1),
    "radiator": ("radiator", 1),
    "ray": ("ray", -1),
    "rider": ("rider", 2),
    "runway": ("runway", 3),
    "seal": ("seal", -1),
    "shrew": ("shrew", 1),
    "sign": ("sign", 1),
    "skunk": ("skunk", 1),
    "stage": ("stage", 2),
    "step": ("step", 3),
    "table": ("table", 1),
    "tiger": ("tiger", 1),
    "toilet": ("toilet", 1),
    "traffic_sign": ("street_sign", 0),
    "turtle": ("turtle", 1),
    "tvmonitor": ("tv_monitor", 0),
    "upper-clothes": ("top", 9),
    "van": ("van", -1),
    "washer": ("washer", 2),
    "water": ("water", 1),
    "whale": ("whale", 1),
}

# Class name -> synset. All lookups run eagerly, in table order, as soon as
# this statement executes.
hardcoded_mapping = {
    class_name: wn.synsets(lemma, pos=wn.NOUN)[sense]
    for class_name, (lemma, sense) in _SYNSET_OVERRIDES.items()
}

# Resolve each class name to a WNID. Manual overrides win; otherwise the first
# noun synset WordNet returns is used; if there is none, a fake synset built
# from the class's position stands in and the name is recorded as a failure.
wnids = []
for index, name in enumerate(classes):
    if name not in hardcoded_mapping:
        candidates = wn.synsets(name, pos=wn.NOUN)
        if candidates:
            synset = candidates[0]
        else:
            Colors.red(f"==> Failed to find synset for {name}. Using fake synset...")
            failures.append(name)
            synset = FakeSynset.create_from_offset(index)
    else:
        synset = hardcoded_mapping[name]

    wnid = synset_to_wnid(synset)
    print(f"{wnid}: ({name}) {synset.definition()}")
    wnids.append(wnid)

# The WNID file is only produced in dataset mode; with --classes the printed
# lines above are the whole output.
if not args.classes:
    write_wnids(wnids, path)
    Colors.green(f"==> Wrote to {path}")

# Summarize every class that fell back to a fake synset.
if failures:
    Colors.red(f"==> Warning: failed to find wordnet IDs for {failures}")
